This analysis uses ABCD Release 4.0
# Start from a clean session: remove every object from the global
# environment and run the garbage collector.
# NOTE(review): `scriptfold` is used in the source() call below but is not
# defined in the visible code, and rm(list=ls()) would delete it if it had
# been set earlier -- confirm where it is defined.
rm(list=ls())
gc()
# tidyverse supplies the pipe, map()/map2(), the dplyr/tidyr verbs, stringr,
# forcats and ggplot2 used below; ggthemes supplies theme_few().
library(tidyverse)
library(ggthemes)
# Source the project's R functions file (contents not visible here;
# presumably shared helpers -- TODO confirm).
source(paste0(scriptfold,"stacking_gfactor_modelling/r_functions.R"))
Set up parallel processing.
# parallel for ubuntu
# Active backend: register doParallel with 15 cores.
doParallel::registerDoParallel(cores=15)
# The two doFuture alternatives below are kept for reference only; both are
# commented out and therefore not executed.
## this one works for ubuntu but slow
#library(doFuture)
#registerDoFuture()
#plan(multicore(workers = 30))
### parallel for windows
#library(doFuture)
#registerDoFuture()
#plan(multisession(workers = 30))
# Indices 1..21 of the saved result files that are read below.
index_vec <- seq_len(21)
### Read one saved random-forest SHAP result file.
###   vec_input  : index that appears at the end of the file name (default 1)
###   time_input : time-point label that appears in the file name
###                (default "baseline")
### The file is looked up under
###   <scriptfold>stacking_gfactor_modelling/rf_shap_results/
### where `scriptfold` is a global defined outside this function.
### Returns the object stored in that .RDS file.
process_shap_results <- function(vec_input=1,time_input="baseline"){
  results_dir <- paste0(scriptfold, "stacking_gfactor_modelling/rf_shap_results/")
  results_path <- paste0(
    results_dir,
    "random_forest_", time_input, "_results_shap_", vec_input, ".RDS"
  )
  readRDS(results_path)
}
# Read the baseline result file for every index (time_input keeps its
# default, "baseline").
rf_shap_baseline <- map(index_vec, function(idx) process_shap_results(vec_input = idx))
### extract the site information
# Split each result into its test data, training data and SHAP values.
rf_shap_baseline_test_data <- map(
  rf_shap_baseline,
  function(res) pluck(res, "test_data")
)
rf_shap_baseline_train_data <- map(
  rf_shap_baseline,
  function(res) pluck(res, "train_data")
)
rf_shap_baseline_shap <- map(
  rf_shap_baseline,
  function(res) pluck(res, "model_shap")
)
### Return the distinct values of the SITE_ID_L column/element of data_input.
site_extract <- function(data_input){
  unique(data_input[["SITE_ID_L"]])
}
# Site identifiers found in each baseline test set, row-bound into one object.
rf_shap_baseline_site <- do.call(
  rbind,
  map(rf_shap_baseline_test_data, function(test_set) site_extract(data_input = test_set))
)
# Same loading and splitting steps for the follow-up results.
rf_shap_followup <- map(
  index_vec,
  function(idx) process_shap_results(vec_input = idx, time_input = "followup")
)
rf_shap_followup_test_data <- map(
  rf_shap_followup,
  function(res) pluck(res, "test_data")
)
rf_shap_followup_train_data <- map(
  rf_shap_followup,
  function(res) pluck(res, "train_data")
)
rf_shap_followup_shap <- map(
  rf_shap_followup,
  function(res) pluck(res, "model_shap")
)
rf_shap_followup_site <- do.call(
  rbind,
  map(rf_shap_followup_test_data, function(test_set) site_extract(data_input = test_set))
)
Setting up the continuous color palette:
# 24 hex colours used as the continuous colour gradient in the SHAP summary
# plots (passed to scale_color_gradientn()).
brain_plot_color <- c(
  "#240105", "#380208", "#450109", "#52010b",
  "#69010e", "#8a0313", "#990315", "#b00419",
  "#B2182B", "#D6604D", "#F4A582", "#FDDBC7",
  "#D1E5F0", "#92C5DE", "#4393C3", "#2166AC",
  "#0358ad", "#02509e", "#01488f", "#013870",
  "#002c59", "#012a54", "#012142", "#001933"
)
Setting up vectors of column names used to select the relevant parts of the output.
# Identifier columns.
subj_info <- c("SUBJECTKEY", "SITE_ID_L", "EVENTNAME")
# Same columns with the subject key swapped for the gfactor column.
subj_resp_noid <- c("gfactor", subj_info[-1])
Load in the names of plotting titles for all the modalities.
# Read the spreadsheet that holds the plotting names.
plotting_names <- readxl::read_excel(
  paste0(scriptfold, "Common_psy_gene_brain_all/CommonalityPlotingNames.xlsx")
)
# Derive the join key `modality_names` from Original_name by deleting every
# "data" and then every "ROI_" substring; Original_name itself is dropped.
plotting_names_tidy <- plotting_names %>%
  mutate(
    tidied_names = str_remove_all(Original_name, "data"),
    tidied_names = str_remove_all(tidied_names, "ROI\\_")
  ) %>%
  select(-Original_name) %>%
  rename(modality_names = tidied_names)
# Draw a SHAP summary (sina) plot.
#   data_input     : data frame; the columns read here are
#                    plotting_name (label of the feature set),
#                    shapley_values (SHAP value, drawn on the value axis) and
#                    feature_values (mapped to point colour).
#   list_val_input : list of three numbers used as colour-bar breaks; they are
#                    labelled 'min', 'med' and 'max' in that order.
# The colour gradient comes from the global vector brain_plot_color.
# Returns the ggplot object.
shapley_summary_plot <- function(data_input, list_val_input){
  # value-axis limits: the observed SHAP range scaled by 1.1
  shap_axis_limits <- range(data_input$shapley_values) * 1.1
  colour_breaks <- unlist(list_val_input)

  # order the feature sets by their largest SHAP value
  ordered_data <- mutate(
    data_input,
    plotting_name = fct_reorder(plotting_name, shapley_values, .fun = "max")
  )

  feature_colour_scale <- scale_color_gradientn(
    colours = brain_plot_color, na.value = "gray50",
    breaks = colour_breaks, labels = c('min', 'med', 'max'),
    guide = guide_colorbar(direction = "horizontal", barwidth = 12, barheight = 0.3)
  )

  summary_theme <- theme(
    axis.line.y = element_blank(),
    axis.ticks.y = element_blank(),
    legend.position = "bottom",
    legend.direction = 'horizontal',
    legend.title = element_text(size = 20),
    legend.text = element_text(size = 20),
    legend.box = "horizontal",
    axis.title.x = element_text(size = 20),
    axis.text.y = element_text(size = 15),
    axis.text.x = element_text(size = 20)
  )

  ggplot(ordered_data, aes(x = plotting_name, y = shapley_values, color = feature_values)) +
    coord_flip(ylim = shap_axis_limits) +
    # reference line at SHAP = 0, drawn beneath the points
    geom_hline(yintercept = 0) +
    ggforce::geom_sina(method = "counts", maxwidth = 0.7, alpha = 0.7) +
    feature_colour_scale +
    theme_bw() +
    summary_theme +
    labs(y = "SHAP value (impact on model output)", x = "", color = "Feature value  ")
}
### Build the SHAP summary plots and the |SHAP| importance plot for one site.
###
### Arguments
###   shapley_input : SHAP values (anything tibble::as_tibble() accepts) whose
###                   column names end in "large" or "small".
###   data_input    : the data the SHAP values belong to; must contain
###                   SUBJECTKEY and feature columns ending in "large" / "small",
###                   with rows in the same order as shapley_input.
###                   NOTE(review): the "large"/"small" pairs presumably code
###                   missing values as 1000 / -1000 -- confirm.
### Uses the globals plotting_names_tidy and shapley_summary_plot().
###
### Returns a list with
###   summary_plot_na              summary plot with feature values of +/-1000
###                                set to NA
###   summary_plot_1k              summary plot with the +/-1000 values kept
###   importance_plot              bar plot of summed |SHAP| per feature set
###   shapley_vi_tibble_with_names the table behind importance_plot
###
### Fixes relative to the previous version:
###   * feature values were built from bind_rows(data_small, data_small), so the
###     "large" feature columns were never used; SHAP and feature values are
###     now paired per suffix ("large" with "large", "small" with "small");
###   * the "small" SHAP columns were renamed with names derived from the
###     "large" columns; each table is now renamed from its own column names;
###   * colour-bar breaks are computed with na.rm = TRUE.
shapley_data_process_plotting <- function(shapley_input, data_input){
  shapley_tbl <- tibble::as_tibble(shapley_input)
  data_tbl <- tibble::as_tibble(data_input)
  subject_keys <- data_input$SUBJECTKEY

  # Columns ending in `suffix`, reshaped to one row per subject and feature
  # set; the suffix is removed from the feature-set name.
  long_by_suffix <- function(wide_tbl, suffix, values_to){
    wide_tbl %>%
      dplyr::select(ends_with(suffix)) %>%
      mutate(SUBJECTKEY = subject_keys) %>%
      pivot_longer(-SUBJECTKEY, names_to = "modality_names", values_to = values_to) %>%
      mutate(modality_names = str_remove_all(.data[["modality_names"]], suffix))
  }

  # SHAP values (with plotting names) joined to the feature values of the
  # SAME suffix. Joining within a suffix is required because modality_names
  # no longer distinguishes "large" from "small" once the suffix is removed.
  paired_by_suffix <- function(suffix){
    shap_long <- long_by_suffix(shapley_tbl, suffix, "shapley_values") %>%
      left_join(plotting_names_tidy, by = "modality_names")
    feature_long <- long_by_suffix(data_tbl, suffix, "feature_values")
    full_join(shap_long, feature_long, by = c("SUBJECTKEY", "modality_names"))
  }

  # distinct() is kept so fully identical rows collapse as they did before.
  shapley_summary_tibble <- bind_rows(
    paired_by_suffix("large"),
    paired_by_suffix("small")
  ) %>%
    distinct()

  ##shapley summary plot
  # version with the +/-1000 feature values turned into NA (drawn in gray)
  shapley_summary_tibble_NA <- shapley_summary_tibble %>%
    naniar::replace_with_na(replace = list(feature_values = c(-1000, 1000)))
  # version with the +/-1000 feature values set to 0; only used for the
  # colour-bar breaks of the NA plot
  shapley_summary_tibble_0 <- shapley_summary_tibble
  is_sentinel <- shapley_summary_tibble_0$feature_values %in% c(-1000, 1000)
  shapley_summary_tibble_0$feature_values[is_sentinel] <- 0

  # colour-bar breaks: min, median and max of the feature values
  myfuns <- list(min, median, max)
  list_val <- lapply(
    myfuns,
    function(f) f(shapley_summary_tibble$feature_values, na.rm = TRUE)
  )
  list_val_0 <- lapply(
    myfuns,
    function(f) f(shapley_summary_tibble_0$feature_values, na.rm = TRUE)
  )

  shapley_summary_plot_1k <- shapley_summary_plot(
    data_input = shapley_summary_tibble,
    list_val_input = list_val
  )
  shapley_summary_plot_na <- shapley_summary_plot(
    data_input = shapley_summary_tibble_NA,
    list_val_input = list_val_0
  )

  ### plot the sum of the absolute shap values
  # SHAP columns of one suffix, renamed to the bare feature-set name so that
  # bind_rows() stacks "small" and "large" values of a set in one column.
  shap_wide_by_suffix <- function(suffix){
    wide <- dplyr::select(shapley_tbl, ends_with(suffix))
    names(wide) <- str_remove_all(names(wide), suffix)
    wide
  }
  shapley_vi <- bind_rows(
    shap_wide_by_suffix("small"),
    shap_wide_by_suffix("large")
  ) %>%
    abs() %>%
    colSums()

  shapley_vi_tibble <- tibble(
    modality_names = names(shapley_vi),
    shapley_vi = shapley_vi
  )
  shapley_vi_tibble_with_names <- left_join(
    shapley_vi_tibble, plotting_names_tidy,
    by = "modality_names"
  )

  shapley_vi_plot <- shapley_vi_tibble_with_names %>%
    mutate(plotting_name = fct_reorder(plotting_name, shapley_vi, .fun = "max")) %>%
    ggplot(aes(x = plotting_name, y = shapley_vi)) +
    geom_bar(stat = "identity") +
    coord_flip() +
    labs(y = "|SHAP| variable importance", x = "") +
    theme(axis.title.x = element_text(size = 20),
          axis.title.y = element_text(size = 20),
          axis.text.y = element_text(size = 15),
          axis.text.x = element_text(size = 20))

  return(list(summary_plot_na = shapley_summary_plot_na,
              summary_plot_1k = shapley_summary_plot_1k,
              importance_plot = shapley_vi_plot,
              shapley_vi_tibble_with_names = shapley_vi_tibble_with_names
  ))
}
Trial code: produce the output plots for a single site.
# Run the plotting function on the first baseline result only and print the
# returned list.
plots_one_site <- shapley_data_process_plotting(
  rf_shap_baseline_shap[[1]],
  rf_shap_baseline_train_data[[1]]
)
plots_one_site
## $summary_plot_na
##
## $summary_plot_1k
##
## $importance_plot
##
## $shapley_vi_tibble_with_names
## # A tibble: 45 × 3
## modality_names shapley_vi plotting_name
## <chr> <dbl> <chr>
## 1 DTI_ 218. DTI
## 2 rsmri_within_avg_ 352. rsfMRI cortical FC
## 3 smri_T2_mean_total_ 89.0 T2 summations
## 4 smri_T1_mean_total_ 45.2 T1 summations
## 5 Normalised_T2_ 121. T2 normalised intensity
## 6 Avg_T2_Gray_ 310. T2 gray matter avg intensity
## 7 Avg_T2_White_ 171. T2 white matter avg intensity
## 8 Normalised_T1_ 171. T1 normalised intensity
## 9 Avg_T1_Gray_ 413. T1 gray matter avg intensity
## 10 Avg_T1_White_ 263. T1 white matter avg intensity
## # ℹ 35 more rows
Map the output across all sites
# Pair each SHAP table with its training data and build the plots, once per
# element, for baseline and then for follow-up.
shap_plots_baseline <- map2(
  rf_shap_baseline_shap,
  rf_shap_baseline_train_data,
  function(shap_values, train_set) {
    shapley_data_process_plotting(shapley_input = shap_values, data_input = train_set)
  }
)
shap_plots_followup <- map2(
  rf_shap_followup_shap,
  rf_shap_followup_train_data,
  function(shap_values, train_set) {
    shapley_data_process_plotting(shapley_input = shap_values, data_input = train_set)
  }
)
Extract the sum values
Compute the means across sites and sd across sites.
# Keep only the summed-|SHAP| table from each element of the plotting output.
shap_baseline_sum <- map(
  shap_plots_baseline,
  function(site_output) pluck(site_output, "shapley_vi_tibble_with_names")
)
shap_followup_sum <- map(
  shap_plots_followup,
  function(site_output) pluck(site_output, "shapley_vi_tibble_with_names")
)
# Name the list elements after the extracted site identifiers.
names(shap_baseline_sum) <- rf_shap_baseline_site
names(shap_followup_sum) <- rf_shap_followup_site
### Rename the columns of a per-site table, by position, to
### "modality_names", <site_input>, "plotting_name" and return the table.
shap_sum_change_names <- function(data_input,site_input){
  stats::setNames(data_input, c("modality_names", site_input, "plotting_name"))
}
### change names according to sites
# The value column of each per-site table is renamed to that site's
# identifier so the tables can be joined side by side.
shap_baseline_sum_site <- map2(
  shap_baseline_sum,
  rf_shap_baseline_site,
  function(site_table, site_id) {
    shap_sum_change_names(data_input = site_table, site_input = site_id)
  }
)
shap_baseline_cross_site <- plyr::join_all(
  shap_baseline_sum_site,
  by = c("modality_names", "plotting_name")
)
### compute row mean and sd
# Only the per-site value columns enter the row-wise statistics.
shap_baseline_cross_site_num <- select(
  shap_baseline_cross_site,
  -all_of(c("modality_names", "plotting_name"))
)
mean_sum_shap_baseline <- rowMeans(shap_baseline_cross_site_num)
sd_sum_shap_baseline <- apply(shap_baseline_cross_site_num, 1, sd)
shap_baseline_cross_site <- mutate(
  shap_baseline_cross_site,
  mean_cross_site = mean_sum_shap_baseline,
  sd_cross_site = sd_sum_shap_baseline
)
### change names according to sites
# Same steps as for baseline, applied to the follow-up tables.
shap_followup_sum_site <- map2(
  shap_followup_sum,
  rf_shap_followup_site,
  function(site_table, site_id) {
    shap_sum_change_names(data_input = site_table, site_input = site_id)
  }
)
shap_followup_cross_site <- plyr::join_all(
  shap_followup_sum_site,
  by = c("modality_names", "plotting_name")
)
### compute row mean and sd
shap_followup_cross_site_num <- select(
  shap_followup_cross_site,
  -all_of(c("modality_names", "plotting_name"))
)
mean_sum_shap_followup <- rowMeans(shap_followup_cross_site_num)
sd_sum_shap_followup <- apply(shap_followup_cross_site_num, 1, sd)
shap_followup_cross_site <- mutate(
  shap_followup_cross_site,
  mean_cross_site = mean_sum_shap_followup,
  sd_cross_site = sd_sum_shap_followup
)
Plotting the baseline and followup shap values individually based on the magnitude.
### Bar plot of the cross-site mean |SHAP| per feature set, with error bars
### of +/- one cross-site sd.
###   data_input : data frame containing mean_cross_site, sd_cross_site and
###                plotting_name
###   time_input : label appended to the value-axis title (default "baseline")
### Returns the ggplot object.
shap_sum_cross_site <- function(data_input,time_input = "baseline"){
  needed_cols <- c("mean_cross_site", "sd_cross_site", "plotting_name")
  value_axis_title <- paste0( "|SHAP| variable importance at ", time_input)

  # bars are ordered by their cross-site mean
  bar_data <- data_input %>%
    select(all_of(needed_cols)) %>%
    mutate(plotting_name = fct_reorder(plotting_name, mean_cross_site, .fun = "max"))

  sd_bars <- geom_errorbar(
    aes(x = plotting_name,
        ymin = mean_cross_site - sd_cross_site,
        ymax = mean_cross_site + sd_cross_site),
    width = 0.4, colour = "black", alpha = 0.9, linewidth = 1.3
  )

  ggplot(bar_data, aes(x = plotting_name, y = mean_cross_site)) +
    geom_bar(stat = "identity", fill = "gray30", alpha = 0.7) +
    sd_bars +
    coord_flip() +
    labs(y = value_axis_title, x = "") +
    theme_few() +
    theme(axis.title.x = element_text(size = 20),
          axis.title.y = element_text(size = 20),
          axis.text.y = element_text(size = 15),
          axis.text.x = element_text(size = 15))
}
# Build and display the cross-site importance plot for each time point.
shap_sum_plot_baseline <- shap_sum_cross_site(shap_baseline_cross_site)
shap_sum_plot_baseline
shap_sum_plot_followup <- shap_sum_cross_site(
  shap_followup_cross_site,
  time_input = "followup"
)
shap_sum_plot_followup
Plot the summary plots side by side. The baseline feature sets are ordered by the magnitude of their mean |SHAP| value, and the follow-up plot uses the same order as the baseline plot.
# Columns needed for the baseline panel.
shap_baseline_cross_site_select <- select(
  shap_baseline_cross_site,
  all_of(c("mean_cross_site", "sd_cross_site", "plotting_name"))
)
# Baseline panel: bars ordered by cross-site mean, error bars of +/- one sd,
# axis titles left empty.
shapley_vi_plot_baseline <- ggplot(
  mutate(
    shap_baseline_cross_site_select,
    plotting_name = fct_reorder(plotting_name, mean_cross_site, .fun = "max")
  ),
  aes(x = plotting_name, y = mean_cross_site)
) +
  geom_bar(stat = "identity", fill = "gray30", alpha = 0.7) +
  geom_errorbar(
    aes(x = plotting_name,
        ymin = mean_cross_site - sd_cross_site,
        ymax = mean_cross_site + sd_cross_site),
    width = 0.4, colour = "black", alpha = 0.9, linewidth = 1.3
  ) +
  coord_flip() +
  labs(y = "", x = "") +
  theme_few() +
  theme(axis.title.x = element_text(size = 20),
        axis.title.y = element_text(size = 20),
        axis.text.y = element_text(size = 15),
        axis.text.x = element_text(angle = 90, size = 15, vjust = -0.2))
#shapley_vi_plot_baseline
# get the order of sets of brain scan features in baseline based on the magnitude
# (ascending mean_cross_site)
shap_baseline_cross_site_select_reordered <- arrange(
  shap_baseline_cross_site_select,
  mean_cross_site
)
shap_baseline_cross_site_select_reordered_names <-
  shap_baseline_cross_site_select_reordered[["plotting_name"]]
shap_followup_cross_site_select <- select(
  shap_followup_cross_site,
  all_of(c("mean_cross_site", "sd_cross_site", "plotting_name"))
)
### reorder followup based on baseline reordered as described above
# The baseline order becomes the factor levels of the follow-up labels.
shap_followup_cross_site_select_reordered <- mutate(
  shap_followup_cross_site_select,
  plotting_name = factor(
    as.factor(plotting_name),
    levels = shap_baseline_cross_site_select_reordered_names
  )
)
# Follow-up panel: same geometry as the baseline panel, with the feature-set
# axis title and labels blanked.
shapley_vi_plot_followup <- ggplot(
  shap_followup_cross_site_select_reordered,
  aes(x = plotting_name, y = mean_cross_site)
) +
  geom_bar(stat = "identity", fill = "gray30", alpha = 0.7) +
  geom_errorbar(
    aes(x = plotting_name,
        ymin = mean_cross_site - sd_cross_site,
        ymax = mean_cross_site + sd_cross_site),
    width = 0.4, colour = "black", alpha = 0.9, linewidth = 1.3
  ) +
  coord_flip() +
  labs(y = "", x = "") +
  theme_few() +
  theme(axis.title.x = element_text(size = 20),
        axis.title.y = element_blank(),
        axis.text.y = element_blank(),
        axis.text.x = element_text(angle = 90, size = 20))
Align the two figures side by side and add shared labels.
#shapley_vi_plot_followup
# Put a time-point heading above each panel.
shapley_vi_label_baseline <- ggpubr::annotate_figure(
  shapley_vi_plot_baseline,
  top = ggpubr::text_grob("Baseline", size = 20, hjust = -1)
)
shapley_vi_label_followup <- ggpubr::annotate_figure(
  shapley_vi_plot_followup,
  top = ggpubr::text_grob("Followup", size = 20, hjust = 1)
)
# One row, two columns; the baseline panel is three times as wide as the
# follow-up panel.
shapley_vi_plot_aligned <- ggpubr::ggarrange(
  shapley_vi_label_baseline,
  shapley_vi_label_followup,
  widths = c(3, 1), heights = c(1, 1),
  ncol = 2, nrow = 1
  #,common.legend=TRUE,legend = "top"
)
# Shared axis labels and overall title for the combined figure.
labelled_shapley_vi_plot <- ggpubr::annotate_figure(
  shapley_vi_plot_aligned,
  left = ggpubr::text_grob("Sets of Features from the Brain", size = 25, rot = 90),
  bottom = ggpubr::text_grob("Mean |SHAP| Values", size = 25, hjust = -0.3),
  top = ggpubr::text_grob(
    "Feature importance of Each Set of Features \nfrom the Brain in Predicting Cognitive Abilities",
    size = 30, face = "bold"
  )
)
labelled_shapley_vi_plot